{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where the pretrained model files are stored.\n",
    "# Later cells pass it to from_pretrained() and read vocab.txt from it.\n",
    "# The value below is a placeholder: replace it with your local path.\n",
    "BERT_BASE_DIR = '/path/to/mecab-ipadic-bpe-32k/do-whole-word-mask'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1106 11:10:32.632133 4675833280 file_utils.py:32] TensorFlow version 2.0.0 available.\n",
      "I1106 11:10:32.632836 4675833280 file_utils.py:39] PyTorch version 1.3.0.post2 available.\n",
      "I1106 11:10:32.935988 4675833280 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import BertForMaskedLM\n",
    "from tokenization import MecabBertTokenizer, MecabCharacterBertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1106 11:10:35.046508 4675833280 configuration_utils.py:148] loading configuration file /Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/do-whole-word-mask/config.json\n",
      "I1106 11:10:35.047818 4675833280 configuration_utils.py:168] Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"finetuning_task\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"num_labels\": 2,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"output_past\": true,\n",
      "  \"pruned_heads\": {},\n",
      "  \"torchscript\": false,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_bfloat16\": false,\n",
      "  \"vocab_size\": 32000\n",
      "}\n",
      "\n",
      "I1106 11:10:35.049305 4675833280 modeling_utils.py:334] loading weights file /Users/m-suzuki/work/japanese-bert/jawiki-20190901/mecab-ipadic-bpe-32k/do-whole-word-mask/pytorch_model.bin\n",
      "I1106 11:10:37.441706 4675833280 modeling_utils.py:408] Weights from pretrained model not used in BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n"
     ]
    }
   ],
   "source": [
    "# Load the masked-language-model from the model directory.\n",
    "# The load log reports the checkpoint's cls.seq_relationship.* weights as\n",
    "# not used by BertForMaskedLM.\n",
    "model = BertForMaskedLM.from_pretrained(BERT_BASE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the tokenizer from the vocab.txt file inside the model directory.\n",
    "tokenizer = MecabBertTokenizer(vocab_file=f'{BERT_BASE_DIR}/vocab.txt')\n",
    "# Use MecabCharacterBertTokenizer instead for char-4k models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example sentence containing one [MASK] slot for the model to fill in.\n",
    "text = '朝食に[MASK]を焼いて食べました。'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the sentence to vocabulary ids. add_special_tokens=True adds the\n",
    "# leading [CLS] and trailing [SEP] that appear in the decoded tokens below.\n",
    "token_ids = tokenizer.encode(text, add_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2, 25965, 7, 4, 11, 16878, 16, 2949, 3913, 10, 8, 3]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the encoded ids (a plain Python list at this point).\n",
    "token_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the ids back to token strings to inspect how the text was segmented.\n",
    "tokens = tokenizer.convert_ids_to_tokens(token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]', '朝食', 'に', '[MASK]', 'を', '焼い', 'て', '食べ', 'まし', 'た', '。', '[SEP]']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the token strings, including the [CLS], [MASK] and [SEP] special tokens.\n",
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the ids into a tensor with a leading batch dimension of 1.\n",
    "# as_tensor + reshape keeps this cell idempotent: re-running it on the\n",
    "# already-converted tensor yields the same [1, seq_len] result instead of\n",
    "# failing on torch.tensor([<tensor>]).\n",
    "token_ids = torch.as_tensor(token_ids).reshape(1, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    2, 25965,     7,     4,    11, 16878,    16,  2949,  3913,    10,\n",
       "             8,     3]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the tensor: a single row (batch of 1) holding all token ids.\n",
    "token_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[CLS]'] ['た', '」', '朝食', 'て', 'お', '、', 'まし', 'です', '朝', '。']\n",
      "['朝食'] ['朝食', '朝', '夕食', '早朝', '最初', '午後', '昼', '代わり', '食事', '夕方']\n",
      "['に'] ['に', 'は', 'として', 'の', '用', 'を', '中', '後', 'で', 'にかけて']\n",
      "['[MASK]'] ['パン', '肉', 'ご飯', '豚肉', 'ハム', '野菜', '牛肉', '[UNK]', 'ケーキ', 'バター']\n",
      "['を'] ['を', '##パン', 'は', '##焼き', 'に', '##肉', 'パン', '##を', '##ト', 'で']\n",
      "['焼い'] ['焼い', '焼く', '焼き', '作っ', 'し', '燃やし', '巻い', '使っ', '揚げ', '買っ']\n",
      "['て'] ['て', 'で', '##て', 'ながら', 'たら', 'を', 'から', 'って', 'た', 'に']\n",
      "['食べ'] ['食べ', 'い', '飲み', 'おり', '待ち', '食べる', '食', '食事', '見', 'もらい']\n",
      "['まし'] ['まし', 'でし', 'ましょ', 'ませ', 'ます', 'です', 'し', 'られ', 'て', 'でしょ']\n",
      "['た'] ['た', 'て', '」', 'ます', 'まし', 'し', 'たい', 'たら', 'ん', 'たり']\n",
      "['。'] ['。', '」', '!', 'が', '...。', '、', 'から', ':', '......', 'よ']\n",
      "['[SEP]'] ['。', 'て', '、', 'に', '(', 'た', 'は', '」', 'し', ')。']\n"
     ]
    }
   ],
   "source": [
    "# Inference only: no_grad() skips building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    # Take the first returned item (the per-position vocabulary scores) by\n",
    "    # index rather than by single-element unpacking, so this keeps working\n",
    "    # if the model returns additional items.\n",
    "    predictions = model(token_ids)[0]\n",
    "\n",
    "# Ids of the 10 highest-scoring vocabulary entries at every position.\n",
    "_, top10_pred_ids = torch.topk(predictions, k=10, dim=2)\n",
    "\n",
    "# For each input token, print it next to the model's top-10 candidates.\n",
    "for correct_id, pred_ids in zip(token_ids[0], top10_pred_ids[0]):\n",
    "    correct_token = tokenizer.convert_ids_to_tokens([correct_id.item()])\n",
    "    pred_tokens = tokenizer.convert_ids_to_tokens(pred_ids.tolist())\n",
    "    print(correct_token, pred_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
